{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tWhm0JFMPJ5I"
      },
      "source": [
        "Copyright 2024 Google LLC.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RJcqTAlfPQjk"
      },
      "source": [
        "# [OSS] JAX to TFLite with StableHLO Quantization Demonstration for ODML."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cqeGmbO6PPNd"
      },
      "source": [
        "This example takes a Keras ResNet50 reference model running on the JAX backend, converts it into a StableHLO module via `jax2tf`, and then quantizes it in the ODML Converter with the StableHLO Quantizer.\n",
        "\n",
        "Note: This API is experimental and may break on other models. If it does, please reach out to [scalable-opt-team@google.com](mailto:scalable-opt-team@google.com) so we can support your use case."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-S0P42BpPSeJ"
      },
      "source": [
        "## StableHLO Quantizer\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FacwMD9MPUew"
      },
      "source": [
        "StableHLO Quantizer is a quantization API that operates at the StableHLO level, which keeps it independent of the source ML framework and retargetable to different hardware."
      ]
    },
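    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough intuition for what calibration-based quantization does, the next cell is a minimal NumPy sketch of static-range int8 affine quantization. It is an illustration under stated assumptions only: it does not reproduce the StableHLO Quantizer's actual algorithm, and the helper names `quantize_int8` and `dequantize` are hypothetical."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def quantize_int8(x, x_min, x_max):\n",
        "  # Hypothetical helper: map the calibrated range [x_min, x_max] onto int8.\n",
        "  # Assumes x_max is strictly greater than x_min.\n",
        "  scale = (x_max - x_min) / 255.0\n",
        "  zero_point = int(round(-128 - x_min / scale))\n",
        "  q = np.clip(np.round(x / scale) + zero_point, -128, 127).astype(np.int8)\n",
        "  return q, scale, zero_point\n",
        "\n",
        "def dequantize(q, scale, zero_point):\n",
        "  # Hypothetical helper: recover an approximation of the float values.\n",
        "  return (q.astype(np.float32) - zero_point) * scale\n",
        "\n",
        "rng = np.random.default_rng(seed=0)\n",
        "x = rng.uniform(low=-1.0, high=1.0, size=(4, 4)).astype(np.float32)\n",
        "# The observed min/max stand in for calibration statistics here.\n",
        "q, scale, zero_point = quantize_int8(x, float(x.min()), float(x.max()))\n",
        "x_hat = dequantize(q, scale, zero_point)\n",
        "print('scale:', scale, 'zero_point:', zero_point)\n",
        "print('max abs error:', float(np.abs(x - x_hat).max()))"
      ]
    },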
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RXZUHZQoQZOo"
      },
      "outputs": [],
      "source": [
        "# Remove any preinstalled TensorFlow so that tf-nightly replaces it cleanly.\n",
        "!pip uninstall tensorflow --yes"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aYz36YEKPYRk"
      },
      "outputs": [],
      "source": [
        "!pip3 install tf-nightly\n",
        "!pip3 install keras-core"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "duab6P-nPZzF"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "print(\"TensorFlow version:\", tf.__version__)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c9JX9RJTPaoW"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "# Select the JAX backend; this must be set before keras_core is imported.\n",
        "os.environ['KERAS_BACKEND'] = 'jax'\n",
        "import jax.numpy as jnp\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "from keras_core.applications import ResNet50\n",
        "from jax.experimental import jax2tf"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rTcHwDPBPchd"
      },
      "outputs": [],
      "source": [
        "input_shape = (1, 224, 224, 3)\n",
        "\n",
        "# Wrap the Keras model's forward pass as a TF-callable function. Native\n",
        "# serialization embeds the JAX computation as a StableHLO module.\n",
        "jax_callable = jax2tf.convert(\n",
        "    ResNet50(\n",
        "        input_shape=input_shape[1:],\n",
        "        pooling='avg',\n",
        "    ).call,\n",
        "    with_gradient=False,\n",
        "    native_serialization=True,\n",
        "    native_serialization_platforms=('cpu',),\n",
        ")\n",
        "\n",
        "tf_module = tf.Module()\n",
        "tf_module.f = tf.function(\n",
        "    jax_callable,\n",
        "    autograph=False,\n",
        "    input_signature=[\n",
        "        tf.TensorSpec(input_shape, jnp.float32, 'lhs_operand')\n",
        "    ],\n",
        ")\n",
        "\n",
        "saved_model_dir = '/tmp/saved_model'\n",
        "tf.saved_model.save(tf_module, saved_model_dir)\n",
        "\n",
        "# Calibration data: two random batches, keyed by the input name declared in\n",
        "# the TensorSpec above.\n",
        "def calibration_dataset():\n",
        "  rng = np.random.default_rng(seed=1235)\n",
        "  for _ in range(2):\n",
        "    yield {\n",
        "        'lhs_operand': rng.uniform(low=-1.0, high=1.0, size=input_shape).astype(\n",
        "            np.float32\n",
        "        )\n",
        "    }\n",
        "\n",
        "converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)\n",
        "converter.target_spec.supported_ops = [\n",
        "    tf.lite.OpsSet.SELECT_TF_OPS,  # enable TensorFlow ops.\n",
        "    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TFL ops.\n",
        "]\n",
        "converter.representative_dataset = calibration_dataset\n",
        "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
        "# Use the StableHLO Quantizer (True) instead of the TFLite quantizer (False).\n",
        "converter.experimental_use_stablehlo_quantizer = True\n",
        "\n",
        "quantized_model = converter.convert()\n",
        "\n",
        "with open('/tmp/resnet50_quantized.tflite', 'wb') as f:\n",
        "  f.write(quantized_model)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "u3b9Xj8dPdXo"
      },
      "outputs": [],
      "source": [
        "# Report the quantized model size in whole mebibytes (size in bytes shifted right by 20).\n",
        "print(str(os.path.getsize('/tmp/resnet50_quantized.tflite') \u003e\u003e 20) + ' MiB')"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/stablehlo_quantizer_odml_oss.ipynb",
          "timestamp": 1712841250910
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
